"""Check GitLab repos for the recommended configuration."""

import argparse
import datetime
import functools
import os
import re
import sys
from urllib import parse

from cki_lib import config_tree
from cki_lib import gitlab
from cki_lib import logger
from cki_lib import misc
from cki_lib import yaml
import dateutil.parser
from gitlab import exceptions as gl_exceptions
import sentry_sdk

# Module-level logger; the name matches the tool's module path for log filtering.
LOGGER = logger.get_logger('cki_tools.gitlab_repo_config')


class RepoConfig:
    """Check GitLab repos for the recommended configuration."""

    def __init__(self, fix: bool = False) -> None:
        """Create a new checker.

        Args:
            fix: when True, attempt to fix detected problems instead of
                only reporting them.
        """
        self._fix = fix

    @staticmethod
    @functools.lru_cache
    # lru_cache is ok here as static methods don't suffer from the issues described at
    # https://rednafi.github.io/reflections/dont-wrap-instance-methods-with-functoolslru_cache-decorator-in-python.html
    def _group(group_url):
        """Return the (cached) python-gitlab group object for the given group URL."""
        return gitlab.parse_gitlab_url(group_url)[1]

    @staticmethod
    @functools.lru_cache
    # lru_cache is ok here as static methods don't suffer from the issues described at
    # https://rednafi.github.io/reflections/dont-wrap-instance-methods-with-functoolslru_cache-decorator-in-python.html
    def _project(project_url):
        """Return the (cached) python-gitlab project object for the given project URL."""
        return gitlab.parse_gitlab_url(project_url)[1]

    def discover_group_projects(self, group_url):
        """Return a set of all projects in the given group."""
        gl_group = self._group(group_url)
        urls = set()
        for gl_project in gl_group.projects.list(iterator=True):
            if gl_project.archived:
                continue
            # shared projects show up in the listing as well; only keep the
            # ones that actually live below this group
            if not gl_project.path_with_namespace.startswith(gl_group.full_path):
                continue
            urls.add(gl_project.web_url)
        return urls

    @staticmethod
    def discover_config_groups(configs, only_project_url=None, only_group_url=None):
        """Return a dict of group name -> config for all groups mentioned in the config."""
        groups = {}
        if only_project_url:
            return groups
        for name, config in configs.items():
            if not name.endswith('/group'):
                continue
            group_url = f'{config["url"]}/groups/{config["group_path"]}'
            if not only_group_url or only_group_url == group_url:
                groups[group_url] = config
        LOGGER.debug('Discovered: %s', groups.keys())
        return groups

    def discover_config_projects(self, configs, only_project_url=None, only_group_url=None):
        """Return a dict of project name -> config for all projects mentioned in the config."""
        discovered = {}
        if only_group_url:
            # a group-only run never needs any project-level checks
            return discovered
        for name, config in configs.items():
            if name.endswith('/group'):
                continue
            base_url = f'{config["url"]}/{config["group_path"]}'
            if not name.endswith('/default'):
                # explicitly configured project: always wins over group defaults
                project_url = f'{base_url}/{name.split("/")[-1]}'
                if not only_project_url or only_project_url == project_url:
                    discovered[project_url] = config
                continue
            # group default config: applies to all projects below the group,
            # but must not override explicit per-project configs (setdefault)
            if only_project_url:
                if only_project_url.startswith(base_url):
                    discovered.setdefault(only_project_url, config)
            else:
                group_url = f'{config["url"]}/groups/{config["group_path"]}'
                for project_url in self.discover_group_projects(group_url):
                    discovered.setdefault(project_url, config)
        LOGGER.debug('Discovered: %s', discovered.keys())
        return discovered

    def check_ldap_group_links(self, gl_group, config):
        """Check that LDAP links are configured correctly.

        A missing key or a None value disables the check; an explicitly
        empty mapping means "no links expected" and removes existing ones.
        """
        # use an "is None" guard like check_saml_group_links does: the old
        # truthiness check returned early for an explicitly-empty config and
        # made the "or {}" fallback below dead code
        if (config_ldaplinks := config.get('ldap_group_links')) is None:
            return
        links = {t.cn: t.attributes
                 for t in gl_group.ldap_group_links.list(iterator=True)}
        # drop expected entries explicitly disabled with a None value
        expected_links = {k: v
                          for k, v in (config_ldaplinks or {}).items() if v is not None}
        for name, expected in expected_links.items():
            if name not in links:
                print(f'  missing LDAP link {name}')
                if self._fix:
                    print('  fixing')
                    gl_group.ldap_group_links.create({'cn': name} | expected)
                continue

            # only compare the attributes the config actually specifies
            if {k: v for k, v in links[name].items() if k in expected} != expected:
                print(f'  unexpected LDAP link {name}: '
                      f'{links[name]} != {expected}')
                if self._fix:
                    print('  fixing')
                    # links cannot be updated in place; recreate them
                    gl_group.ldap_group_links.delete(name)
                    gl_group.ldap_group_links.create({'cn': name} | expected)

        for name in set(links.keys()) - set(expected_links.keys()):
            print(f'  superfluous LDAP link {name}')
            if self._fix:
                print('  fixing')
                gl_group.ldap_group_links.delete(name)

    def check_saml_group_links(self, gl_group, config):
        """Check that SAML links are configured correctly."""
        if (config_samllinks := config.get('saml_group_links')) is None:
            return
        actual = {link.name: link.attributes
                  for link in gl_group.saml_group_links.list(iterator=True)}
        # drop expected entries explicitly disabled with a None value
        wanted = {name: attrs
                  for name, attrs in (config_samllinks or {}).items()
                  if attrs is not None}
        for name, expected in wanted.items():
            current = actual.get(name)
            if current is None:
                print(f'  missing SAML link {name}')
                if self._fix:
                    print('  fixing')
                    gl_group.saml_group_links.create({'saml_group_name': name} | expected)
            elif {key: current[key] for key in current if key in expected} != expected:
                print(f'  unexpected SAML link {name}: '
                      f'{current} != {expected}')
                if self._fix:
                    print('  fixing')
                    # links cannot be updated in place; recreate them
                    gl_group.saml_group_links.delete(name)
                    gl_group.saml_group_links.create({'saml_group_name': name} | expected)

        for name in set(actual) - set(wanted):
            print(f'  superfluous SAML link {name}')
            if self._fix:
                print('  fixing')
                gl_group.saml_group_links.delete(name)

    @staticmethod
    def check_tags(gl_project, config):
        """Check that there are no tags except those matching the configured patterns.

        A missing key or a None value disables the check.
        """
        if (config_tags := config.get('tags')) is None:
            return
        tags = {t.name for t in gl_project.tags.list(iterator=True)}
        for tag_pattern in (config_tags or []):
            tags -= {t for t in tags if re.fullmatch(tag_pattern, t)}
        # config_tags was already checked against None above, so the former
        # re-check of config['tags'] was redundant
        if tags:
            print(f'  unexpected tags {tags}')

    @staticmethod
    def check_branches(gl_project, config):
        """Check that there are no branches but the default branch.

        Branches matching one of the configured patterns are allowed as well.
        A missing key or a None value disables the check.
        """
        # guard like check_tags() does: the previous code indexed
        # config['branches'] after using config.get('branches'), raising
        # KeyError for configs without a 'branches' key
        if (config_branches := config.get('branches')) is None:
            return
        branches = {b.name for b in gl_project.branches.list(iterator=True)}
        branches.discard(gl_project.default_branch)
        for branch_pattern in config_branches:
            branches -= {b for b in branches if re.fullmatch(branch_pattern, b)}
        if branches:
            print(f'  unexpected branches {branches}')

    def check_mirror(self, gl_project, config):
        """Check that mirror settings are configured correctly."""
        # a missing or empty 'mirror' config disables the check
        if not (expected := config.get('mirror')):
            return
        scrubbed = {**expected}
        if import_url := scrubbed.get('import_url'):
            parts = parse.urlparse(import_url)
            # remove username and password from URL
            # NOTE(review): netloc=parts.hostname also drops an explicit port
            # and lower-cases the host — presumably mirror URLs never carry a
            # port; confirm before relying on this comparison for such URLs
            scrubbed['import_url'] = parts._replace(netloc=parts.hostname).geturl()
        # compare only the fields that the expected config specifies
        mirror_settings = {k: v for k, v in gl_project.attributes.items() if k in expected}
        if scrubbed != mirror_settings or gl_project.attributes.get('import_error'):
            print(f'  unexpected mirror settings: {mirror_settings} != {scrubbed}')
            if self._fix:
                print('  fixing')
                # push the unscrubbed values (including any credentials)
                for field_name, field_value in expected.items():
                    setattr(gl_project, field_name, field_value)
                gl_project.save()

    def check_project_or_group_field(self, gl_project_or_group, field_name, expected):
        """Check that a single project/group attribute has the expected value.

        Works for both projects and groups; fields not exposed by the GitLab
        instance are skipped with a warning. (The previous docstring about
        "private builds" was a copy-paste leftover.)
        """
        LOGGER.debug('Checking %s for %s', gl_project_or_group.path, field_name)
        if field_name not in gl_project_or_group.attributes:
            LOGGER.warning('Field %s not found', field_name)
            return
        if (actual := gl_project_or_group.attributes[field_name]) != expected:
            print(f'  unexpected {field_name}: {actual} != {expected}')
            if self._fix:
                print('  fixing')
                setattr(gl_project_or_group, field_name, expected)
                gl_project_or_group.save()

    @staticmethod
    def _access_level_dict(data):
        """Convert a GitLab API access level dict to something comparable."""
        return {
            k: {
                'user_ids': {a['user_id'] for a in v if a.get('user_id') is not None},
                'group_ids': {a['group_id'] for a in v if a.get('group_id') is not None},
                'access_level': next((a['access_level'] for a in v
                                      if a.get('user_id') is None and a.get('group_id') is None),
                                     None),
            } if k.endswith('_levels') else v
            for k, v in data.items()
        }

    @classmethod
    def _access_level_equal(cls, data, expected):
        """Compare tag/branch access levels."""
        normalized_expected = cls._access_level_dict(expected)
        # only the keys mentioned in the expected config are compared
        normalized_actual = {
            key: value
            for key, value in cls._access_level_dict(data).items()
            if key in normalized_expected
        }
        return normalized_actual == normalized_expected

    @classmethod
    def _access_level_create(cls, name, data):
        """Prepare the POST data to create tag/branch protection rules."""
        normalized = cls._access_level_dict(data)
        result = {'name': name}
        for key, value in data.items():
            if not key.endswith('_levels'):
                result[key] = value
            elif normalized[key]['user_ids'] or normalized[key]['group_ids']:
                # per-user/per-group grants need the allowed_to_* API form
                action = key.split('_', 1)[0]
                result[f'allowed_to_{action}'] = value
            else:
                # a plain level can be posted as e.g. push_access_level
                result[key[:-1]] = normalized[key]['access_level']
        return result

    @classmethod
    def _expand_user_ids(cls, config):
        """Expand user_ids lists."""
        def expand(entries):
            # keep plain entries first (in order), then turn each user_ids
            # list into individual {'user_id': ...} entries
            expanded = [entry for entry in entries if 'user_ids' not in entry]
            for entry in entries:
                expanded.extend({'user_id': user_id}
                                for user_id in entry.get('user_ids', []))
            return expanded

        return {key: expand(value) if key.endswith('_levels') else value
                for key, value in config.items()}

    def _check_protections(self, gl_manager, expected_items, kind):
        """Reconcile one protection manager (tags or branches) with the config.

        Args:
            gl_manager: the python-gitlab manager (protectedtags or
                protectedbranches) to check.
            expected_items: mapping of protection name -> expected
                (user_ids-expanded) attributes.
            kind: label used in the messages ('tag' or 'branch').
        """
        actual = {p.name: p.attributes for p in gl_manager.list(iterator=True)}
        for name, expected in expected_items.items():
            if name not in actual:
                print(f'  missing {kind} protection {name}')
                if self._fix:
                    print('  fixing')
                    gl_manager.create(self._access_level_create(name, expected))
                continue

            if not self._access_level_equal(actual[name], expected):
                print(f'  unexpected {kind} protection {name}: '
                      f'{actual[name]} != {expected}')
                if self._fix:
                    print('  fixing')
                    # protections cannot be updated in place; recreate them
                    gl_manager.delete(name)
                    gl_manager.create(self._access_level_create(name, expected))

        for name in set(actual.keys()) - set(expected_items.keys()):
            print(f'  superfluous {kind} protection {name}')
            if self._fix:
                print('  fixing')
                gl_manager.delete(name)

    def check_protectedtags(self, gl_project, config):
        """Check that tag protection is configured correctly."""
        if (config_protectedtags := config.get('protectedtags')) is None:
            return
        # drop entries explicitly disabled with a None value
        expected_tags = {k: self._expand_user_ids(v)
                         for k, v in (config_protectedtags or {}).items() if v is not None}
        self._check_protections(gl_project.protectedtags, expected_tags, 'tag')

    def check_protectedbranches(self, gl_project, config):
        """Check that branch protection is configured correctly."""
        if (config_protectedbranches := config.get('protectedbranches')) is None:
            return
        # drop entries explicitly disabled with a None value
        expected_branches = {k: self._expand_user_ids(v)
                             for k, v in (config_protectedbranches or {}).items()
                             if v is not None}
        self._check_protections(gl_project.protectedbranches, expected_branches, 'branch')

    def check_pushrules(self, gl_project, config):
        """Check that pushrules are configured correctly."""
        try:
            gl_pushrules = gl_project.pushrules.get()
        except gl_exceptions.GitlabGetError:
            # push rules not available for this project; nothing to check
            return

        expected = config.get('pushrules', {})
        # only compare the fields that the expected config specifies
        actual = {key: value
                  for key, value in gl_pushrules.attributes.items()
                  if key in expected}
        if actual == expected:
            return
        print(f'  unexpected pushrules settings: {actual} != {expected}')
        if self._fix:
            print('  fixing')
            gl_project.pushrules.update(**expected)

    def check_approvals(self, gl_project, config):
        """Check that approval settings and approval rules are configured correctly."""
        try:
            gl_approvals = gl_project.approvals.get()
        except gl_exceptions.GitlabGetError:
            # approvals API not available for this project
            return

        # copy so the CODEOWNERS fallback below doesn't mutate the shared
        # config dict, which is reused for other projects discovered from the
        # same group default
        expected = dict(config.get('approvals', {}))

        if expected.get('selective_code_owner_removals'):
            try:
                gl_project.files.get('CODEOWNERS', ref='main')
            except gl_exceptions.GitlabGetError:
                # selective removals require a CODEOWNERS file on main
                print('  falling back to resetting all approvals without CODEOWNERS')
                expected['reset_approvals_on_push'] = True
                expected['selective_code_owner_removals'] = False

        # only compare the fields that the expected config specifies
        if {k: v for k, v in gl_approvals.attributes.items() if k in expected} != expected:
            print(f'  unexpected approval settings: {gl_approvals.attributes} != {expected}')
            if self._fix:
                print('  fixing')
                gl_project.approvals.update(**expected)

        expected_rules = config.get('approvalrules', {})
        gl_approvalrules = {r.name: r for r in gl_project.approvalrules.list(iterator=True)}

        # use a distinct name for the per-rule config instead of shadowing
        # the approval settings variable above
        for name, expected_rule in expected_rules.items():
            if not (gl_approvalrule := gl_approvalrules.get(name)):
                print(f'  missing approval rule {name}')
                if self._fix:
                    print('  fixing')
                    gl_project.approvalrules.create({'name': name, **expected_rule})
                continue

            # normalize the users/groups lists into the id lists the config uses
            gl_approvalrule.user_ids = [g['id'] for g in gl_approvalrule.users]
            gl_approvalrule.group_ids = [g['id'] for g in gl_approvalrule.groups]
            if {k: v for k, v in gl_approvalrule.attributes.items()
                    if k in expected_rule} != expected_rule:
                print(f'  unexpected approval settings in rule {name}')
                if self._fix:
                    print('  fixing')
                    gl_project.approvalrules.update(gl_approvalrule.id, expected_rule)

        for name in set(gl_approvalrules.keys()) - set(expected_rules.keys()):
            print(f'  superfluous approval rule {name}')
            if self._fix:
                print('  fixing')
                gl_approvalrules[name].delete()

    def check_environments(self, gl_project, config):
        """Check that no stale environments exist.

        Environments matching one of the config['environments']['keep']
        patterns are always retained; anything else counts as expired once it
        is older than config['environments']['expire_days'] days and has no
        successful deployment newer than that cutoff.
        """
        if gl_project.environments_access_level == 'disabled':
            return
        expired = []

        cutoff = (misc.now_tz_utc() -
                  datetime.timedelta(days=config['environments']['expire_days']))
        try:
            # GraphQL lets us fetch the last successful deployment together
            # with the environments in a single paged query
            result = gitlab.get_graphql_client(gl_project.manager.gitlab.url).query(
                '''
                query environments($fullPath: ID!, $after: String = "") {
                    project(fullPath: $fullPath) {
                        environments(after: $after) {
                            nodes {
                                id name state createdAt
                                lastDeployment(status: SUCCESS) { updatedAt }
                            }
                            pageInfo { hasNextPage endCursor }
                        }
                    }
                }
                ''',
                variable_values={'fullPath': gl_project.path_with_namespace},
                paged_key='project/environments',
                operation_name='environments',
            )
        except Exception:  # pylint: disable=broad-except
            # GitLab version might not support the requested GraphQL Environment nodes
            print('  unable to list environments, outdated GitLab version?')
            return

        for environment in misc.get_nested_key(result, 'project/environments/nodes'):
            # explicitly kept environments never expire
            if any(re.fullmatch(e, environment['name']) for e in config['environments']['keep']):
                continue
            # recently created environments are kept
            if dateutil.parser.parse(environment['createdAt']) > cutoff:
                continue
            # environments with a recent successful deployment are kept
            if ((updated_at := misc.get_nested_key(environment, 'lastDeployment/updatedAt'))
                    and dateutil.parser.parse(updated_at) > cutoff):
                continue
            print(f'  expired environment {environment["name"]}')
            expired.append(environment)
        if self._fix and expired:
            print('  fixing')
            for environment in expired:
                # the GraphQL id looks like gid://gitlab/Environment/<id>;
                # the REST API needs the trailing numeric id
                gl_environment = gl_project.environments.get(environment['id'].split('/')[-1],
                                                             lazy=True)
                # environments must be stopped before they can be deleted
                if environment['state'] != 'stopped':
                    gl_environment.stop()
                try:
                    gl_environment.delete()
                except gl_exceptions.GitlabDeleteError:
                    print(f'  unable to delete {environment["name"]}')

    def _delete_container_images_once(self, gl_repository, rules) -> bool:
        """Delete one batch of outdated container image tags.

        Tags matched by a rule pattern are kept according to the rule's
        keep_n/keep_younger settings (or kept entirely for an empty rule);
        all remaining tags are reported and, with --fix, deleted.

        Returns True if anything was deleted so the caller can loop until the
        repository is clean.
        """
        deleted = False
        # fetch all tags with their creation timestamps via a paged GraphQL query
        result = gitlab.get_graphql_client(gl_repository.manager.gitlab.url).query(
            '''
            query images($id: ContainerRepositoryID!, $after: String="") {
                containerRepository(id: $id) {
                    tags(after: $after) {
                        nodes { name createdAt }
                        pageInfo { hasNextPage endCursor }
                    }
                }
            }
            ''',
            variable_values={'id': f'gid://gitlab/ContainerRepository/{gl_repository.id}'},
            paged_key='containerRepository/tags',
            operation_name='images',
        )
        tags = {t['name']: t for t in misc.get_nested_key(result, 'containerRepository/tags/nodes')}
        keep_tags = set()
        for tag_pattern, rule in rules.items():
            # newest first so matching_tags[:keep_n] keeps the most recent tags
            # NOTE(review): tags without a createdAt timestamp never match any
            # rule and so are always treated as outdated — confirm intended
            matching_tags = sorted((t for t in tags.values()
                                    if t['createdAt'] and re.fullmatch(tag_pattern, t['name'])),
                                   key=lambda t: t['createdAt'], reverse=True)
            keep_n = (rule or {}).get('keep_n')
            keep_younger = (rule or {}).get('keep_younger')
            if keep_n is not None or keep_younger is not None:
                if keep_n is not None:
                    # keep the newest keep_n matching tags
                    keep_tags |= {t['name'] for t in matching_tags[:keep_n]}
                if keep_younger is not None:
                    # keep matching tags younger than the configured age
                    cutoff = (misc.now_tz_utc() - misc.parse_timedelta(keep_younger))
                    keep_tags |= {t['name'] for t in matching_tags
                                  if dateutil.parser.parse(t['createdAt']) > cutoff}
            else:
                # a rule without keep_n/keep_younger keeps everything it matches
                keep_tags |= {t['name'] for t in matching_tags}
        for tag in keep_tags:
            tags.pop(tag, None)
        for tag in tags:
            print(f'  outdated {tag}')
        if self._fix and tags:
            print('  fixing')
            for tag in tags:
                print(f'  removing {tag}')
                deleted = True
                try:
                    gl_repository.tags.delete(tag)
                except gl_exceptions.GitlabDeleteError:
                    print(f'  unable to delete {tag}')
        return deleted

    def check_container_images(self, gl_project, config):
        """Check that no stale container images exist."""
        if not gl_project.container_registry_enabled or not config.get('container_registry'):
            return

        for gl_repository in gl_project.repositories.list(iterator=True):
            print(f'  processing container images at {gl_repository.name}')
            registries = [rules for registry_pattern, rules in config['container_registry'].items()
                          if re.fullmatch(registry_pattern, gl_repository.name)]
            if len(registries) > 1:
                print(f'  unable to process {gl_repository.name} matching multiple rules')
                continue
            # fall back to the optional 'default' rules; using .get avoids a
            # KeyError crash when neither a pattern nor a 'default' entry
            # matches — the existing "if not rules" guard then skips cleanly
            rules = registries[0] if registries else config['container_registry'].get('default')
            if not rules:
                continue
            # loop until there is nothing deleted anymore
            # at the moment, tags.list() will only return the first 100 tags 🙈
            while self._delete_container_images_once(gl_repository, rules):
                pass

    def run(self, configs, only_group_url=None, only_project_url=None):
        """Run all checkers."""
        print('Discovering groups')
        group_configs = self.discover_config_groups(configs, only_project_url, only_group_url)
        for group_url, config in group_configs.items():
            print(f'Processing {group_url}')
            gl_group = self._group(group_url)
            self.check_ldap_group_links(gl_group, config)
            self.check_saml_group_links(gl_group, config)
            for field, expected in config['group'].items():
                self.check_project_or_group_field(gl_group, field, expected)

        print('Discovering projects')
        project_configs = self.discover_config_projects(configs, only_project_url, only_group_url)
        for project_url, config in project_configs.items():
            print(f'Processing {project_url}')
            if config['ignore']:
                LOGGER.debug('Ignoring %s', project_url)
                continue

            gl_project = self._project(project_url)
            # repository-related checks only make sense with a repository
            if gl_project.repository_access_level != 'disabled':
                for checker in (self.check_protectedtags,
                                self.check_protectedbranches,
                                self.check_tags,
                                self.check_branches,
                                self.check_approvals,
                                self.check_pushrules,
                                self.check_environments,
                                self.check_container_images,
                                self.check_mirror):
                    checker(gl_project, config)
            for field, expected in config['project'].items():
                self.check_project_or_group_field(gl_project, field, expected)


def main(args):
    """Run the main CLI interface.

    Args:
        args: command line arguments (without the program name).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--fix', action='store_true',
                        help='Fix detected problems')
    parser.add_argument('--project-url',
                        help='Only run the checker for one project')
    parser.add_argument('--group-url',
                        help='Only run the checker for one group')
    # --config carries the YAML *contents* (it is passed to
    # yaml.load(contents=...) below), while --config-path points at a file;
    # the old help text wrongly described both options as a path
    parser.add_argument('--config',
                        default=os.environ.get('GITLAB_REPO_CONFIG'),
                        help='Contents of the config file')
    parser.add_argument('--config-path',
                        default=os.environ.get('GITLAB_REPO_CONFIG_PATH', 'config.yml'),
                        help='Path to the config file')
    parsed_args = parser.parse_args(args)

    configs = config_tree.process_config_tree(yaml.load(
        contents=parsed_args.config,
        file_path=parsed_args.config_path,
        resolve_references=True,
        resolve_includes=True,
    ))

    RepoConfig(fix=parsed_args.fix).run(configs,
                                        only_group_url=parsed_args.group_url,
                                        only_project_url=parsed_args.project_url)


if __name__ == '__main__':
    # initialize Sentry error reporting before doing any real work
    misc.sentry_init(sentry_sdk)
    main(sys.argv[1:])
